import os
import numpy as np
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn.functional import cosine_similarity

from utility.parser import parse_args
from utility.norm import build_sim, build_knn_normalized_graph, get_dense_laplacian
# Run configuration is parsed once, as a side effect of importing this module,
# and kept in the module-level global `args`.
# NOTE(review): none of the definitions visible in this file read `args` --
# confirm whether other code relies on it before moving or removing this call.
args = parse_args()

class EmbLoss(nn.Module):
    """Norm-based regularization term over a group of embedding tensors.

    The p-norm of every tensor passed to ``forward`` is summed and the sum is
    divided by the size of the first dimension of the *last* tensor.
    """

    def __init__(self, norm=2):
        super().__init__()
        # Order ``p`` of the norm handed to ``torch.norm``.
        self.norm = norm

    def forward(self, *embeddings):
        """Return a one-element tensor holding the scaled sum of norms."""
        reference = embeddings[-1]
        total = torch.zeros(1).to(reference.device)
        for tensor in embeddings:
            # In-place accumulation keeps the accumulator's dtype and device.
            total.add_(torch.norm(tensor, p=self.norm))
        total.div_(reference.shape[0])
        return total

class MGCN(nn.Module):
    """Graph recommender that fuses ID embeddings with image and text item features.

    ``forward`` combines two views:

    * a user-item view: ID embeddings propagated over ``norm_adj``;
    * an item-item view: ID embeddings gated by each modality's projected
      features, propagated over a per-modality item graph and aggregated to
      users through the user-item block ``R`` of the adjacency.

    All auxiliary tensors (raw features, item-item graphs, ``R``) are kept on
    the device of ``adj``, so the model follows the adjacency's device instead
    of requiring CUDA.
    """

    # ``topk`` forwarded to ``build_knn_normalized_graph`` for both modalities.
    _ITEM_GRAPH_TOPK = 10

    def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats=None, text_feats=None, adj=None, edge_index=None):
        """Build the model.

        Args:
            n_users: Number of users (rows of the user embedding table).
            n_items: Number of items (rows of the item embedding table).
            embedding_dim: Size of every embedding / hidden representation.
            weight_size: Unused by this class; kept for call compatibility.
            dropout_list: Unused by this class; kept for call compatibility.
            image_feats: 2-D array of raw image features, one row per item
                (rows must line up with the item embedding table because the
                gated features are multiplied element-wise with it).
            text_feats: 2-D array of raw text features, same row convention.
            adj: Adjacency over users followed by items; it is multiplied with
                the concatenated ``[user; item]`` embeddings, and its
                ``[:n_users, n_users:]`` block is used as the user-item matrix.
                Presumably already normalized by the caller -- TODO confirm.
            edge_index: Unused by this class; kept for call compatibility.

        Raises:
            ValueError: If ``image_feats``, ``text_feats`` or ``adj`` is None.
        """
        super().__init__()
        if image_feats is None or text_feats is None or adj is None:
            raise ValueError(
                "MGCN requires image_feats, text_feats and adj; "
                "at least one of them is None"
            )

        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        # NOTE: modules are created in a fixed order so that seeded random
        # initialization stays reproducible.
        self.user_embedding = nn.Embedding(n_users, embedding_dim)
        self.item_embedding = nn.Embedding(n_items, embedding_dim)
        self.image_trs = nn.Linear(image_feats.shape[1], embedding_dim)
        self.text_trs = nn.Linear(text_feats.shape[1], embedding_dim)
        self.predictor = nn.Linear(self.embedding_dim, self.embedding_dim)

        self.norm_adj = adj
        self.n_layers = 1
        self.n_ui_layers = 1

        # Everything that is not a parameter lives where the adjacency lives.
        device = adj.device

        # Sparse user-item block of the adjacency.
        self.R = self._extract_interaction_matrix(adj, self.n_users)

        nn.init.xavier_uniform_(self.user_embedding.weight)
        nn.init.xavier_uniform_(self.item_embedding.weight)
        nn.init.xavier_normal_(self.predictor.weight)
        nn.init.xavier_normal_(self.image_trs.weight)
        nn.init.xavier_normal_(self.text_trs.weight)

        self.text_feats = torch.Tensor(text_feats).to(device)
        self.image_feats = torch.Tensor(image_feats).to(device)
        self.softmax = nn.Softmax(dim=-1)

        # Scores each modality embedding; the two scores are soft-maxed in
        # ``forward`` to weight the shared representation.
        self.query_common = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim),
            nn.Tanh(),
            nn.Linear(self.embedding_dim, 1, bias=False)
        )

        # Gates applied to the item ID embeddings, driven by modality features.
        self.gate_v = self._make_gate(self.embedding_dim)
        self.gate_t = self._make_gate(self.embedding_dim)

        # Gates applied to the modality-specific residuals, driven by the
        # user-item view embeddings.
        self.gate_image_prefer = self._make_gate(self.embedding_dim)
        self.gate_text_prefer = self._make_gate(self.embedding_dim)

        self.image_original_adj = self._build_item_graph(image_feats, device)
        self.text_original_adj = self._build_item_graph(text_feats, device)

    @staticmethod
    def _make_gate(dim):
        """Return a ``Linear(dim, dim)`` followed by a sigmoid."""
        return nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid()
        )

    def _build_item_graph(self, feats, device):
        """Build one modality's item-item graph from raw features and move it to ``device``."""
        sim = build_sim(torch.Tensor(feats))
        graph = build_knn_normalized_graph(
            sim, topk=self._ITEM_GRAPH_TOPK, is_sparse=True, norm_type='sym')
        return graph.to(device)

    @staticmethod
    def _extract_interaction_matrix(adj, n_users):
        """Return ``adj[:n_users, n_users:]`` as a sparse COO tensor.

        A sparse COO ``adj`` is sliced directly on its indices, which avoids
        materializing the full dense adjacency. Any other layout falls back to
        densifying first.
        """
        if not adj.is_sparse:
            return adj.to_dense()[:n_users, n_users:].to_sparse()

        adj = adj.coalesce()
        rows, cols = adj.indices()
        keep = (rows < n_users) & (cols >= n_users)
        # Shift column ids so the first item column becomes column 0.
        indices = torch.stack([rows[keep], cols[keep] - n_users])
        shape = (min(n_users, adj.shape[0]), max(adj.shape[1] - n_users, 0))
        return torch.sparse_coo_tensor(indices, adj.values()[keep], shape).coalesce()

    def forward(self, training=1):
        """Compute user and item representations.

        Args:
            training: Mode switch. ``2`` replaces every item's projected image
                feature with the mean over items; ``3`` does the same for the
                projected text feature; any other value (default ``1``) uses
                both modalities unchanged.

        Returns:
            Tuple ``(user_embeddings, item_embeddings, side_embeds,
            content_embeds)``. The first two are the split of
            ``content_embeds + side_embeds`` into users and items; the last two
            cover users followed by items.
        """
        text_feats = F.leaky_relu_(self.text_trs(self.text_feats))
        image_feats = F.leaky_relu_(self.image_trs(self.image_feats))

        if training == 2:
            image_feats = image_feats.mean(dim=0).tile(self.n_items, 1)
        elif training == 3:
            text_feats = text_feats.mean(dim=0).tile(self.n_items, 1)

        item_embeds = self.item_embedding.weight
        user_embeds = self.user_embedding.weight

        # Behavior-Guided Purifier: modality features gate the ID embeddings.
        image_item_embeds = torch.multiply(item_embeds, self.gate_v(image_feats))
        text_item_embeds = torch.multiply(item_embeds, self.gate_t(text_feats))

        # User-Item View: propagate over the adjacency and average all layers
        # (including the input layer).
        ego_embeddings = torch.cat([user_embeds, item_embeds], dim=0)
        layer_outputs = [ego_embeddings]
        for _ in range(self.n_ui_layers):
            ego_embeddings = torch.sparse.mm(self.norm_adj, ego_embeddings)
            layer_outputs.append(ego_embeddings)
        content_embeds = torch.stack(layer_outputs, dim=1).mean(dim=1, keepdim=False)

        # Item-Item View: propagate over each modality's item graph.
        for _ in range(self.n_layers):
            image_item_embeds = torch.mm(self.image_original_adj, image_item_embeds)
            text_item_embeds = torch.mm(self.text_original_adj, text_item_embeds)

        # User-side modality embeddings are aggregated from items through R.
        image_user_embeds = torch.sparse.mm(self.R, image_item_embeds)
        image_embeds = torch.cat([image_user_embeds, image_item_embeds], dim=0)

        text_user_embeds = torch.sparse.mm(self.R, text_item_embeds)
        text_embeds = torch.cat([text_user_embeds, text_item_embeds], dim=0)

        # Behavior-Aware Fuser: attention-weighted shared part plus gated
        # modality-specific residuals.
        att_common = torch.cat(
            [self.query_common(image_embeds), self.query_common(text_embeds)], dim=-1)
        weight_common = self.softmax(att_common)
        common_embeds = (weight_common[:, 0].unsqueeze(dim=1) * image_embeds
                         + weight_common[:, 1].unsqueeze(dim=1) * text_embeds)
        sep_image_embeds = image_embeds - common_embeds
        sep_text_embeds = text_embeds - common_embeds

        image_prefer = self.gate_image_prefer(content_embeds)
        text_prefer = self.gate_text_prefer(content_embeds)
        sep_image_embeds = torch.multiply(image_prefer, sep_image_embeds)
        sep_text_embeds = torch.multiply(text_prefer, sep_text_embeds)
        side_embeds = (sep_image_embeds + sep_text_embeds + common_embeds) / 3

        all_embeds = content_embeds + side_embeds

        all_embeddings_users, all_embeddings_items = torch.split(
            all_embeds, [self.n_users, self.n_items], dim=0)

        return all_embeddings_users, all_embeddings_items, side_embeds, content_embeds